#include "yolo_v2_class.hpp"

#include "network.h"

extern "C" {
#include "detection_layer.h"
#include "region_layer.h"
#include "cost_layer.h"
#include "utils.h"
#include "parser.h"
#include "box.h"
#include "image.h"
#include "demo.h"
#include "option_list.h"
#include "stb_image.h"
}
//#include <sys/time.h>

#include <vector>
#include <iostream>
#include <algorithm>
#include <thread>
#include <future>
#include <functional>

#define FRAMES 3

#ifdef GPU
// Report a failed CUDA runtime status on stdout. This is log-only: nothing is thrown
// and execution continues after an error.
void check_cuda(cudaError_t status) {
	if (status == cudaSuccess)
		return;
	const char *const reason = cudaGetErrorString(status);
	printf("CUDA Error Prev: %s\n", reason);
}
#endif
// Turn per-node conditional scores of a class hierarchy into absolute scores, in place.
//   predictions - n scores, one per node of hier.
//   hier        - supplies parent[] (negative for roots) and leaf[] flags per node.
//   only_leaves - when non-zero, every non-leaf node's score is zeroed afterwards.
// Nodes are visited in index order and each is multiplied by its parent's current score.
// NOTE(review): this assumes a parent's index is smaller than its children's, so the
// parent's score is already final when the child reads it. Confirm for the tree files in use.
void hierarchy_predictions(float *predictions, int n, tree *hier, int only_leaves)
{
    for (int node = 0; node < n; ++node) {
        const int parent = hier->parent[node];
        if (parent < 0)
            continue;
        predictions[node] *= predictions[parent];
    }

    if (!only_leaves)
        return;

    for (int node = 0; node < n; ++node) {
        if (!hier->leaf[node])
            predictions[node] = 0;
    }
}
// Private per-Detector state. Detector stores it behind detector_gpu_ptr and recovers it
// with static_cast. Buffers are allocated in Detector's constructor and released in its
// destructor.
// Default member initializers keep every pointer null and demo_index at 0 even if an
// instance is ever created without value-initialization. demo_index is used as an array
// index in detect(), so an indeterminate value there would be undefined behavior.
struct detector_gpu_t {
	network net;                        // parsed network (weights loaded, batch = 1)
	image images[FRAMES];               // 1x1x3 placeholder images made in the constructor
	float *avg = nullptr;               // mean of predictions[] when detect(use_mean) is on
	float *predictions[FRAMES] = {};    // ring buffer of the last FRAMES raw outputs
	int demo_index = 0;                 // next slot of predictions[] to overwrite
	unsigned int *track_id = nullptr;   // per-class next track id used by tracking_id()
};

// Build a detector from a darknet .cfg file and (optionally) a .weights file.
//   cfg_filename    - network configuration, parsed with batch size 1.
//   weight_filename - weights to load; an empty string skips weight loading.
//   gpu_id          - CUDA device the network runs on (GPU builds only). The caller's
//                     current device is restored before returning.
// Also allocates the temporal-averaging buffers used by detect(use_mean) and the per-class
// track-id counters used by tracking_id(). Counters start at 1, because 0 marks a box
// without an id.
YOLODLL_API Detector::Detector(std::string cfg_filename, std::string weight_filename, int gpu_id) : cur_gpu_id(gpu_id)
{
	wait_stream = 0;
#ifdef GPU
	int old_gpu_index = 0;
	check_cuda( cudaGetDevice(&old_gpu_index) );
#endif

	detector_gpu_ptr = std::make_shared<detector_gpu_t>();
	detector_gpu_t &detector_gpu = *static_cast<detector_gpu_t *>(detector_gpu_ptr.get());

#ifdef GPU
	cuda_set_device(cur_gpu_id);
	printf(" Used GPU %d \n", cur_gpu_id);
#endif
	network &net = detector_gpu.net;

	// darknet's C API takes non-const char*; the strings are NUL-terminated and are not modified.
	char *cfgfile = const_cast<char *>(cfg_filename.c_str());
	char *weightfile = const_cast<char *>(weight_filename.c_str());

	net = parse_network_cfg_custom(cfgfile, 1);
	// c_str() never returns null, so test the string itself to decide whether weights were given.
	if (!weight_filename.empty()) {
		load_weights(&net, weightfile);
	}
	set_batch_network(&net, 1);
	net.gpu_index = cur_gpu_id;
	fuse_conv_batchnorm(net);

	layer const &l = net.layers[net.n - 1];

	// Temporal smoothing: FRAMES copies of the final layer's output plus their mean.
	detector_gpu.avg = static_cast<float *>(calloc(l.outputs, sizeof(float)));
	for (int j = 0; j < FRAMES; ++j) detector_gpu.predictions[j] = static_cast<float *>(calloc(l.outputs, sizeof(float)));
	for (int j = 0; j < FRAMES; ++j) detector_gpu.images[j] = make_image(1, 1, 3);
	detector_gpu.demo_index = 0;

	detector_gpu.track_id = static_cast<unsigned int *>(calloc(l.classes, sizeof(unsigned int)));
	for (int j = 0; j < l.classes; ++j) detector_gpu.track_id[j] = 1;

#ifdef GPU
	check_cuda( cudaSetDevice(old_gpu_index) );
#endif
}


// Release everything the constructor allocated: host-side helper buffers first, then
// the network itself.
YOLODLL_API Detector::~Detector() 
{
	detector_gpu_t &detector_gpu = *static_cast<detector_gpu_t *>(detector_gpu_ptr.get());

	free(detector_gpu.track_id);

	free(detector_gpu.avg);
	for (int j = 0; j < FRAMES; ++j) free(detector_gpu.predictions[j]);
	for (int j = 0; j < FRAMES; ++j) if (detector_gpu.images[j].data) free(detector_gpu.images[j].data);

#ifdef GPU
	// Switch to the network's device for free_network(), then restore the caller's device.
	int old_gpu_index = 0;
	check_cuda( cudaGetDevice(&old_gpu_index) );
	cuda_set_device(detector_gpu.net.gpu_index);
#endif

	free_network(detector_gpu.net);

#ifdef GPU
	check_cuda( cudaSetDevice(old_gpu_index) );
#endif
}

// Width of the network input; detect() resizes every image to this width.
YOLODLL_API int Detector::get_net_width() const {
	const detector_gpu_t *state = static_cast<const detector_gpu_t *>(detector_gpu_ptr.get());
	return state->net.w;
}
// Height of the network input; detect() resizes every image to this height.
YOLODLL_API int Detector::get_net_height() const {
	const detector_gpu_t *state = static_cast<const detector_gpu_t *>(detector_gpu_ptr.get());
	return state->net.h;
}


// Load an image file and run detect(image_t, ...) on it. The decoded pixels are owned by
// this call and freed on every exit path, including when loading or detection throws.
YOLODLL_API std::vector<bbox_t> Detector::detect(std::string image_filename, float thresh, bool use_mean)
{
	// "new image_t()" value-initializes, so data is nullptr until load_image() succeeds.
	// With plain "new image_t", a throwing load_image() (e.g. file not found) left data
	// indeterminate, and the deleter then called free() on a garbage pointer.
	std::shared_ptr<image_t> image_ptr(new image_t(), [](image_t *img) { if (img->data) free(img->data); delete img; });
	*image_ptr = load_image(image_filename);
	return detect(*image_ptr, thresh, use_mean);
}

// Decode an image file with stb_image into a planar (channel-major) float image in [0, 1].
//   filename - path handed to stbi_load.
//   channels - channel count requested from stb_image; 0 keeps the file's own count.
// Throws std::runtime_error if stbi_load fails. The result comes from make_image(), and
// the caller owns its data buffer.
static image load_image_stb(char *filename, int channels)
{
	int width = 0, height = 0, file_channels = 0;
	unsigned char *pixels = stbi_load(filename, &width, &height, &file_channels, channels);
	if (pixels == NULL)
		throw std::runtime_error("file not found");

	// stb_image returns interleaved bytes: 'out_c' values per pixel, row by row.
	const int out_c = channels ? channels : file_channels;
	image im = make_image(width, height, out_c);
	const int plane = width * height;

	for (int y = 0; y < height; ++y) {
		for (int x = 0; x < width; ++x) {
			const int pixel = x + width * y;
			const unsigned char *src = pixels + pixel * out_c;
			for (int ch = 0; ch < out_c; ++ch)
				im.data[pixel + plane * ch] = (float)src[ch] / 255.;
		}
	}

	free(pixels);
	return im;
}

// Load an image file as a 3-channel planar float image_t.
// The returned data buffer belongs to the caller; release it with Detector::free_image().
YOLODLL_API image_t Detector::load_image(std::string image_filename)
{
	const image decoded = load_image_stb(const_cast<char *>(image_filename.c_str()), 3);

	image_t result;
	result.w = decoded.w;
	result.h = decoded.h;
	result.c = decoded.c;
	result.data = decoded.data;
	return result;
}


// Release the pixel buffer of an image_t (e.g. one returned by load_image()).
// An image without data is ignored. Only the buffer is freed; m is passed by value.
YOLODLL_API void Detector::free_image(image_t m)
{
	if (m.data == nullptr)
		return;
	free(m.data);
}

// Run the network on one image and return every box whose best class probability
// exceeds thresh.
//   img      - caller-owned planar float image; it is only read, never freed here.
//   thresh   - minimum class probability for a box to be reported (also passed to
//              get_network_boxes).
//   use_mean - when true, the final layer's output is averaged over the last FRAMES calls
//              before boxes are decoded. Only the final layer is averaged.
// Box coordinates are scaled to img's width/height. x and y are clamped at 0; w and h
// are not clamped. track_id is left at 0 (see tracking_id()).
YOLODLL_API std::vector<bbox_t> Detector::detect(image_t img, float thresh, bool use_mean)
{
	detector_gpu_t &detector_gpu = *static_cast<detector_gpu_t *>(detector_gpu_ptr.get());
	network &net = detector_gpu.net;
#ifdef GPU
	int old_gpu_index = 0;
	cudaGetDevice(&old_gpu_index);
	if (cur_gpu_id != old_gpu_index)
		cudaSetDevice(net.gpu_index);

	net.wait_stream = wait_stream;	// 1 - wait CUDA-stream, 0 - not to wait
#endif

	// Non-owning darknet view of the caller's pixels.
	image im;
	im.c = img.c;
	im.data = img.data;
	im.h = img.h;
	im.w = img.w;

	// The network consumes a private net.w x net.h copy of the input.
	image sized;
	if (net.w == im.w && net.h == im.h) {
		sized = make_image(im.w, im.h, im.c);
		memcpy(sized.data, im.data, im.w*im.h*im.c * sizeof(float));
	}
	else
		sized = resize_image(im, net.w, net.h);

	// Bind a reference, not a copy. get_network_boxes() is handed &net rather than this
	// layer, so the averaged buffer only takes effect if it is installed in the layer that
	// lives inside net. With a copy, use_mean had no effect at all.
	layer &l = net.layers[net.n - 1];

	float *prediction = network_predict(net, sized.data);

	float *const raw_output = l.output;
	if (use_mean) {
		memcpy(detector_gpu.predictions[detector_gpu.demo_index], prediction, l.outputs * sizeof(float));
		mean_arrays(detector_gpu.predictions, FRAMES, l.outputs, detector_gpu.avg);
		l.output = detector_gpu.avg;
		detector_gpu.demo_index = (detector_gpu.demo_index + 1) % FRAMES;
	}

	int nboxes = 0;
	int const letterbox = 0;
	float const hier_thresh = 0.5f;
	detection *dets = get_network_boxes(&net, im.w, im.h, thresh, hier_thresh, 0, 1, &nboxes, letterbox);
	// Reinstall the layer's own buffer. The averaged output was only needed for decoding,
	// and the next forward pass must presumably write into the layer's original buffer.
	l.output = raw_output;
	if (nms) do_nms_sort(dets, nboxes, l.classes, nms);

	std::vector<bbox_t> bbox_vec;

	for (int i = 0; i < nboxes; ++i) {
		box const b = dets[i].bbox;
		int const obj_id = max_index(dets[i].prob, l.classes);
		float const prob = dets[i].prob[obj_id];

		if (prob > thresh)
		{
			// darknet boxes are centre-based and relative to the image; convert to a
			// top-left corner plus size, scaled to img.
			bbox_t bbox;
			bbox.x = std::max((double)0, (b.x - b.w / 2.)*im.w);
			bbox.y = std::max((double)0, (b.y - b.h / 2.)*im.h);
			bbox.w = b.w*im.w;
			bbox.h = b.h*im.h;
			bbox.obj_id = obj_id;
			bbox.prob = prob;
			bbox.track_id = 0;

			bbox_vec.push_back(bbox);
		}
	}

	free_detections(dets, nboxes);
	if (sized.data)
		free(sized.data);

#ifdef GPU
	if (cur_gpu_id != old_gpu_index)
		cudaSetDevice(old_gpu_index);
#endif

	return bbox_vec;
}

// Assign persistent per-class track ids to the boxes of the current frame.
// Each box in cur_bbox_vec is compared against earlier boxes with the same obj_id stored in
// prev_bbox_vec_deque (newest frame first, since frames are added with push_front).
// A previous box whose centre lies within max_dist can hand its track_id to a current box.
// The matched box's w/h are then averaged with the previous box's. Boxes that stay
// unmatched get a fresh id from the per-class counter det_gpu.track_id[obj_id]. The
// constructor starts these counters at 1, so a track_id of 0 means "not assigned yet".
//   cur_bbox_vec   - boxes of the current frame (taken by value; the updated copy is returned).
//   change_history - when true, the returned boxes are pushed into the history.
//   frames_story   - maximum number of frames kept in prev_bbox_vec_deque.
//   max_dist       - maximum centre distance for a match, in the units of bbox_t x/y/w/h.
// NOTE(review): obj_id indexes det_gpu.track_id without a bounds check. It must be smaller
// than the class count the constructor allocated the counters for.
YOLODLL_API std::vector<bbox_t> Detector::tracking_id(std::vector<bbox_t> cur_bbox_vec, bool const change_history, 
	int const frames_story, int const max_dist)
{
	detector_gpu_t &det_gpu = *static_cast<detector_gpu_t *>(detector_gpu_ptr.get());

	// Is there any non-empty frame in the history?
	bool prev_track_id_present = false;
	for (auto &i : prev_bbox_vec_deque)
		if (i.size() > 0) prev_track_id_present = true;

	// Nothing to match against: number every box from its class counter.
	// Note that this path records the frame in the history even when change_history is false.
	if (!prev_track_id_present) {
		for (size_t i = 0; i < cur_bbox_vec.size(); ++i)
			cur_bbox_vec[i].track_id = det_gpu.track_id[cur_bbox_vec[i].obj_id]++;
		prev_bbox_vec_deque.push_front(cur_bbox_vec);
		// NOTE(review): size_t vs int comparison; a negative frames_story never trims the history.
		if (prev_bbox_vec_deque.size() > frames_story) prev_bbox_vec_deque.pop_back();
		return cur_bbox_vec;
	}

	// Last accepted centre distance for each current box.
	std::vector<unsigned int> dist_vec(cur_bbox_vec.size(), std::numeric_limits<unsigned int>::max());

	for (auto &prev_bbox_vec : prev_bbox_vec_deque) {
		for (auto &i : prev_bbox_vec) {
			// Pick the current box (same class) that this previous box i should be matched to.
			int cur_index = -1;
			for (size_t m = 0; m < cur_bbox_vec.size(); ++m) {
				bbox_t const& k = cur_bbox_vec[m];
				if (i.obj_id == k.obj_id) {
					float center_x_diff = (float)(i.x + i.w/2) - (float)(k.x + k.w/2);
					float center_y_diff = (float)(i.y + i.h/2) - (float)(k.y + k.h/2);
					// Euclidean centre distance, truncated to an unsigned integer.
					unsigned int cur_dist = sqrt(center_x_diff*center_x_diff + center_y_diff*center_y_diff);
					// A box without an id accepts any candidate within max_dist (a later one
					// overrides an earlier one). A box that already has an id accepts only a
					// strictly closer candidate.
					// NOTE(review): max_dist is converted to unsigned here, so a negative value
					// accepts every distance.
					if (cur_dist < max_dist && (k.track_id == 0 || dist_vec[m] > cur_dist)) {
						dist_vec[m] = cur_dist;
						cur_index = m;
					}
				}
			}

			// Do not duplicate an id: skip if a current box of this class already carries i's id.
			bool track_id_absent = !std::any_of(cur_bbox_vec.begin(), cur_bbox_vec.end(), 
				[&i](bbox_t const& b) { return b.track_id == i.track_id && b.obj_id == i.obj_id; });

			if (cur_index >= 0 && track_id_absent){
				// Inherit the id and smooth the size with the previous observation.
				cur_bbox_vec[cur_index].track_id = i.track_id;
				cur_bbox_vec[cur_index].w = (cur_bbox_vec[cur_index].w + i.w) / 2;
				cur_bbox_vec[cur_index].h = (cur_bbox_vec[cur_index].h + i.h) / 2;
			}
		}
	}

	// Boxes that matched nothing start a new track.
	for (size_t i = 0; i < cur_bbox_vec.size(); ++i)
		if (cur_bbox_vec[i].track_id == 0)
			cur_bbox_vec[i].track_id = det_gpu.track_id[cur_bbox_vec[i].obj_id]++;

	if (change_history) {
		prev_bbox_vec_deque.push_front(cur_bbox_vec);
		if (prev_bbox_vec_deque.size() > frames_story) prev_bbox_vec_deque.pop_back();
	}

	return cur_bbox_vec;
}
#ifdef OPENCV
// Convert an interleaved 8-bit IplImage into a planar float image_t scaled to [0, 1].
// Channels keep the order they have in the IplImage (no BGR->RGB swap is done here).
// Rows are addressed with widthStep, so padded rows are handled.
// The buffer comes from make_image(); the caller must free img.data.
static image_t ipl_to_image(IplImage* src)
{
	const unsigned char *pixels = reinterpret_cast<const unsigned char *>(src->imageData);
	const int rows = src->height;
	const int cols = src->width;
	const int chans = src->nChannels;
	const int row_stride = src->widthStep;

	image planar = make_image(cols, rows, chans);
	float *dst = planar.data;

	// Output is written sequentially: all of channel 0, then channel 1, ...
	for (int ch = 0; ch < chans; ++ch) {
		for (int y = 0; y < rows; ++y) {
			const unsigned char *row = pixels + y * row_stride;
			for (int x = 0; x < cols; ++x)
				*dst++ = row[x * chans + ch] / 255.;
		}
	}

	image_t result;
	result.data = planar.data;
	result.c = planar.c;
	result.h = planar.h;
	result.w = planar.w;
	return result;
}

#endif
// Create an empty classifier. No network is loaded until Init() is called.
YOLODLL_API Classifier::Classifier() : _Inited(false)
{
}

// Destroy every network created by Init().
// Each entry of nets is a heap-allocated network struct filled by parse_network_cfg().
// Deleting the struct alone leaked every buffer it points to, so free_network() releases
// those first, in the same way Detector's destructor does.
YOLODLL_API Classifier::~Classifier()
{
    if (!_Inited)
        return;

    for (network *net : nets)
    {
        free_network(*net);
        delete net;
    }
    nets.clear();
}
// Load nnet independent copies of the same classifier network. classify() runs one image
// per copy concurrently, so nnet bounds how many images a single call may take.
//   cfg_filename    - darknet .cfg, parsed once per copy; each copy uses batch size 1.
//   weight_filename - weights loaded into every copy; an empty string skips loading.
// Calling Init() again after a successful call does nothing.
YOLODLL_API void Classifier::Init(std::string cfg_filename, std::string weight_filename, int nnet )
{
    if (_Inited) return;
    net_set_size = nnet;
    // darknet's C API takes non-const char*; the strings are NUL-terminated and not modified.
    char *cfgfile = const_cast<char *>(cfg_filename.c_str());
    char *weightfile = const_cast<char *>(weight_filename.c_str());
    for (int i = 0; i < nnet; i++)
    {
        network* net = new network;
        *net = parse_network_cfg(cfgfile);
        // c_str() is never null; test the string itself to decide whether weights were given.
        if (!weight_filename.empty()) load_weights(net, weightfile);
        set_batch_network(net, 1);
        nets.push_back(net);
    }
    _Inited = true;
}

#ifdef OPENCV
// Classify several images in parallel and return one label per image, in input order.
// mats[k] is processed on its own detached thread with nets[k], so at most nets.size()
// images may be passed (otherwise std::runtime_error is thrown).
// If a task's future throws, the error is printed to stderr and that image is skipped,
// so the result can be shorter than mats.
YOLODLL_API std::vector<int> Classifier::classify(const std::vector<cv::Mat>& mats)
{
    if (mats.size() > nets.size()) throw std::runtime_error("bad argument num\n");

    std::vector<std::future<int>> pending;
    for (size_t k = 0; k < mats.size(); ++k)
    {
        std::packaged_task<int(network*, const cv::Mat&)> job(
            [this](network* net, const cv::Mat& frame) { return predictAnImage(net, frame); });
        std::future<int> result = job.get_future();
        // The thread receives its own (shallow) copy of the Mat.
        std::thread(std::move(job), nets[k], mats[k]).detach();
        pending.push_back(std::move(result));
    }

    std::vector<int> labels;
    for (std::future<int>& f : pending)
    {
        try
        {
            labels.push_back(f.get());
        }
        catch (const std::future_error& err)
        {
            std::cerr << "future_error\n" << err.what();
        }
        catch (...)
        {
            std::cerr << "other_error\n";
        }
    }

    return labels;
}

// Like classify(), but return the raw output vector of each network instead of a label.
// mats[k] runs on nets[k] in its own detached thread; at most nets.size() images are allowed.
// Each pointer comes from network_predict() on that network.
// NOTE(review): it presumably aliases the network's output buffer and is overwritten by the
// next prediction on the same network. Copy it out if it must outlive the next call.
// Failed tasks are reported on stderr and skipped, so the result can be shorter than mats.
YOLODLL_API std::vector<float*> Classifier::classifyAndReturnLayers(const std::vector<cv::Mat>& mats)
{
    if (mats.size() > nets.size()) throw std::runtime_error("bad argument num\n");

    std::vector<std::future<float*>> pending;
    for (size_t k = 0; k < mats.size(); ++k)
    {
        std::packaged_task<float*(network*, const cv::Mat&)> job(
            [this](network* net, const cv::Mat& frame) { return predictImageAndReturnBack(net, frame); });
        std::future<float*> result = job.get_future();
        std::thread(std::move(job), nets[k], mats[k]).detach();
        pending.push_back(std::move(result));
    }

    std::vector<float*> outputs;
    for (std::future<float*>& f : pending)
    {
        try
        {
            outputs.push_back(f.get());
        }
        catch (const std::future_error& err)
        {
            std::cerr << "future_error\n" << err.what();
        }
        catch (...)
        {
            std::cerr << "other_error\n";
        }
    }
    return outputs;
}
// Classify one OpenCV image with the given network and return the 1-based label.
// A single-channel Mat is first expanded to 3-channel BGR; any other Mat is used as is.
int Classifier::predictAnImage(network* net,const cv::Mat& mat)
{
    cv::Mat bgr;
    if (mat.channels() != 1)
        bgr = mat;
    else
        cv::cvtColor(mat, bgr, cv::COLOR_GRAY2BGR);

    // Wrap the Mat as an IplImage and convert it to a planar float image we own.
    IplImage ipl(bgr);
    image_t planar = ipl_to_image(&ipl);

    const int label = predictAnImage(net, planar);

    if (planar.data) free(planar.data);
    return label;
}
// Run the network on one OpenCV image and return its raw output vector.
// A single-channel Mat is first expanded to 3-channel BGR; any other Mat is used as is.
// The returned pointer is whatever predictImageAndReturnBack(network*, image_t) returns.
float* Classifier::predictImageAndReturnBack(network* net,const cv::Mat& mat)
{
    cv::Mat bgr;
    if (mat.channels() != 1)
        bgr = mat;
    else
        cv::cvtColor(mat, bgr, cv::COLOR_GRAY2BGR);

    // Wrap the Mat as an IplImage and convert it to a planar float image we own.
    IplImage ipl(bgr);
    image_t planar = ipl_to_image(&ipl);

    float* output = predictImageAndReturnBack(net, planar);

    if (planar.data) free(planar.data);
    return output;
}
#endif
// Classify a planar float image: letterbox it to the network input size, run a forward
// pass, optionally apply the class hierarchy, and return the arg-max index plus one.
// img is only read; its buffer stays owned by the caller.
int Classifier::predictAnImage(network* net,image_t img)
{
    image src;
    src.data = img.data;
    src.c = img.c;
    src.h = img.h;
    src.w = img.w;

    image boxed = letterbox_image(src, net->w, net->h);
    float *scores = network_predict(*net, boxed.data);

    if (net->hierarchy)
        hierarchy_predictions(scores, net->outputs, net->hierarchy, 0);

    int best;
    top_k(scores, net->outputs, 1, &best);

    if (boxed.data != src.data)
        free_image(boxed);

    // Labels are reported 1-based.
    return best + 1;
}
// Run the network on a planar float image and return the raw output of network_predict().
// If the network has a class hierarchy, scores are first converted to absolute scores.
// img is only read; its buffer stays owned by the caller.
// NOTE(review): the returned pointer presumably aliases the network's own output buffer.
// It is not owned by the caller and is overwritten by the next prediction on this network.
float*  Classifier::predictImageAndReturnBack(network* net,image_t img)
{
    image im;
    im.w = img.w;
    im.h = img.h;
    im.c = img.c;
    im.data = img.data;

    image r = letterbox_image(im, net->w, net->h);

    float *predictions = network_predict(*net, r.data);

    if (net->hierarchy) hierarchy_predictions(predictions, net->outputs, net->hierarchy, 0);

    // letterbox_image() allocated a fresh buffer, and it was previously never released
    // (one leak per call). The result is read from network_predict's output, not from r,
    // so r can be freed now. This mirrors predictAnImage(network*, image_t).
    if (r.data != im.data) free_image(r);

    return predictions;
}

